import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader


# 数据预处理
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),  # 对图像进行随机的crop以后再resize成固定大小
    transforms.RandomRotation(20),  # 随机旋转角度
    transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
    transforms.ToTensor()
])

# 读取数据
root = 'image'
train_dataset = datasets.ImageFolder(root + '/train', transform)
test_dataset = datasets.ImageFolder(root + '/test', transform)

# 导入数据
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True)

classes = train_dataset.classes
classes_index = train_dataset.class_to_idx
print(classes)
print(classes_index)

model = models.vgg16(pretrained=True)
print(model)

# 如果我们想只训练模型的全连接层
for param in model.parameters():
    param.requires_grad = False

# 构建新的全连接层
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2))

LR = 0.0003
# 定义代价函数
entropy_loss = nn.CrossEntropyLoss()
# 定义优化器
optimizer = optim.Adam(model.parameters(), LR)


def train():
    model.train()
    for i, data in enumerate(train_loader):
        # 获得数据和对应的标签
        inputs, labels = data
        # 获得模型预测结果，（64，10）
        out = model(inputs)
        # 交叉熵代价函数out(batch,C),labels(batch)
        loss = entropy_loss(out, labels)
        # 梯度清0
        optimizer.zero_grad()
        # 计算梯度
        loss.backward()
        # 修改权值
        optimizer.step()


def test():
    model.eval()
    correct = 0
    for i, data in enumerate(test_loader):
        # 获得数据和对应的标签
        inputs, labels = data
        # 获得模型预测结果
        out = model(inputs)
        # 获得最大值，以及最大值所在的位置
        _, predicted = torch.max(out, 1)
        # 预测正确的数量
        correct += (predicted == labels).sum()
    print("Test acc: {0}".format(correct.item() / len(test_dataset)))

    correct = 0
    for i, data in enumerate(train_loader):
        # 获得数据和对应的标签
        inputs, labels = data
        # 获得模型预测结果
        out = model(inputs)
        # 获得最大值，以及最大值所在的位置
        _, predicted = torch.max(out, 1)
        # 预测正确的数量
        correct += (predicted == labels).sum()
    print("Train acc: {0}".format(correct.item() / len(train_dataset)))


for epoch in range(0, 10):
    print('epoch:', epoch)
    train()
    test()

torch.save(model.state_dict(), 'cat_dog_fc.pth')
